Seeing is believing

Using FlashTorch 🔦 to shine a light on what neural nets "see"


by Misa Ogura

Hello, I'm Misa 👋


  • Originally from Tokyo, now based in London
  • Cancer Cell Biologist, turned Software Engineer
  • Currently at BBC R&D
  • Co-founder of Women Driven Development
  • Women in Data Science London Ambassador

Feature visualisation


Image convolution & CNN 101


Kernel & convolution


Kernel: a small matrix used for blurring, sharpening, embossing, edge detection, etc.

Convolution: an operation that calculates a weighted sum over a local neighbourhood

Example of convolution: detecting edges
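As a sketch of the operation, here is a minimal 2D convolution in plain Python (strictly, cross-correlation, as used in CNNs); the helper function and toy image are made up for illustration, and the kernel is the standard Sobel vertical-edge kernel:

```python
def convolve2d(image, kernel):
    """Valid-mode 2D convolution of a nested-list image with a small kernel."""
    kh, kw = len(kernel), len(kernel[0])
    oh, ow = len(image) - kh + 1, len(image[0]) - kw + 1
    out = [[0] * ow for _ in range(oh)]
    for i in range(oh):
        for j in range(ow):
            # weighted sum of the local neighbourhood
            out[i][j] = sum(image[i + u][j + v] * kernel[u][v]
                            for u in range(kh) for v in range(kw))
    return out

# Sobel kernel: responds strongly to vertical edges
edge_kernel = [[-1, 0, 1],
               [-2, 0, 2],
               [-1, 0, 1]]

# 4x4 toy image: dark left half, bright right half -> a strong vertical edge
img = [[0, 0, 9, 9]] * 4

print(convolve2d(img, edge_kernel))  # [[36, 36], [36, 36]]
```

Every output value is large and positive because each 3x3 window straddles the dark-to-bright boundary.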


Typical CNN architecture


Kernel weights are learnt during training to extract relevant features from input images.

Feature visualisation technique

Saliency maps


Saliency


In human visual perception: a subjective quality that makes certain items stand out and grab our attention

In computer vision: an indication of the most “salient” regions of an image

Saliency maps in CNNs


  • First introduced in 2013

  • Compute the gradient of the output category (target class) w.r.t. the input image

  • Positive gradients: a small change to the pixel will increase the probability of the target class

  • Visualising the gradients provides some intuition of where the network attends
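The core computation can be sketched in a few lines of PyTorch; a tiny linear layer stands in for a CNN, and the names here are illustrative rather than FlashTorch's internals:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 3)              # stand-in for a CNN classifier
x = torch.randn(1, 4, requires_grad=True)  # stand-in for an input image

scores = model(x)                   # raw class scores
target_class = 1
scores[0, target_class].backward()  # d(score of target class) / d(input)

saliency = x.grad.abs()             # magnitude = influence of each "pixel"
print(saliency.shape)               # torch.Size([1, 4])
```

For a linear model the gradient is exactly the weight row of the target class; in a deep network the map depends on the input, which is what makes it informative.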

FlashTorch demo 1

Visualising saliency maps with backpropagation


Install FlashTorch & load an image


First things first...

$ pip install flashtorch
In [2]:
import matplotlib.pyplot as plt

from flashtorch.utils import load_image

image = load_image('../../examples/images/great_grey_owl_01.jpg')

plt.imshow(image)
plt.title('Original image')
plt.axis('off');

Apply transformations


In [3]:
from flashtorch.utils import apply_transforms, denormalize, format_for_plotting

input_ = apply_transforms(image)

print(f'Before: {type(image)}')
print(f'After: {type(input_)}, {input_.shape}')

plt.imshow(format_for_plotting(denormalize(input_)))
plt.title('Input tensor')
plt.axis('off');
Before: <class 'PIL.Image.Image'>
After: <class 'torch.Tensor'>, torch.Size([1, 3, 224, 224])

Create a Backprop object with a pre-trained model


In [4]:
import torchvision.models as models

from flashtorch.saliency import Backprop

model = models.alexnet(pretrained=True)

backprop = Backprop(model)
Signature:

    backprop.calculate_gradients(input_, target_class=None, take_max=False)

Calculate the gradients of target class w.r.t. input


In [5]:
from flashtorch.utils import ImageNetIndex 

imagenet = ImageNetIndex()
target_class = imagenet['great grey owl']

print(f'Target class index: {target_class}')

gradients = backprop.calculate_gradients(input_, target_class)

max_gradients = backprop.calculate_gradients(input_, target_class, take_max=True)

print(type(gradients), gradients.shape)
print(type(max_gradients), max_gradients.shape)
Target class index: 24
<class 'torch.Tensor'> torch.Size([3, 224, 224])
<class 'torch.Tensor'> torch.Size([1, 224, 224])

Let's visualise gradients


In [6]:
from flashtorch.utils import visualize

visualize(input_, gradients, max_gradients)
Pixels where the animal is present have the strongest positive effects.
But it's quite noisy...

FlashTorch demo 2

Visualising saliency maps with guided backpropagation


Guided backpropagation


  • Masks out pixels for which either the ReLU activation or the gradient is non-positive

  • Ignores pixels that didn't contribute to the prediction, i.e. the ReLU activation is zero

  • Ignores pixels that have negative gradients
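The masking rule at a single ReLU can be sketched directly on tensors (toy numbers for illustration, not FlashTorch's actual hook implementation):

```python
import torch

relu_input = torch.tensor([[-1.0, 2.0, 3.0, 0.5]])       # activations entering a ReLU
grad_from_above = torch.tensor([[0.7, -0.2, 0.4, 0.1]])  # gradient flowing back

forward_mask = (relu_input > 0).float()        # ReLU was active on the forward pass
backward_mask = (grad_from_above > 0).float()  # gradient is positive

guided_grad = grad_from_above * forward_mask * backward_mask
print(guided_grad)
```

Only entries where both masks are 1 survive: the first (inactive ReLU) and second (negative gradient) entries are zeroed out, leaving 0.4 and 0.1.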

Calculate the gradients with guided backprop


In [7]:
guided_gradients = backprop.calculate_gradients(input_, target_class, guided=True)

max_guided_gradients = backprop.calculate_gradients(input_, target_class, take_max=True, guided=True)

visualize(input_, guided_gradients, max_guided_gradients)
Now that's much less noisy!
Pixels around the head and eyes have the strongest positive effects.

What about a jay?


In [9]:
visualize(input_, guided_gradients, max_guided_gradients)

Or an oystercatcher?


In [11]:
visualize(input_, guided_gradients, max_guided_gradients)

FlashTorch demo 3

Gaining additional insights on transfer learning


Transfer learning


  • A model developed for a task is reused as a starting point for another task

  • Pre-trained models are often used in computer vision & natural language processing tasks

  • Save compute & time resources

Building a flower classifier


<-- From: DenseNet model, pre-trained on ImageNet (1000 classes)

--> To: Flower classifier recognising 102 species of flowers, using a dataset from the VGG group.

Load a target image


In [12]:
image = load_image('../../examples/images/foxglove.jpg')

plt.imshow(image)
plt.title('Foxglove')
plt.axis('off');

Pre-trained model (no additional training!)


In [15]:
backprop = Backprop(pretrained_model)

guided_gradients = backprop.calculate_gradients(input_, class_index, guided=True)
guided_max_gradients = backprop.calculate_gradients(input_, class_index, take_max=True, guided=True)

visualize(input_, guided_gradients, guided_max_gradients)
/Users/misao/Projects/personal/flashtorch/flashtorch/saliency/backprop.py:93: UserWarning: The predicted class does not equal the target class. Calculating the gradient with respect to the predicted class.

Trained model


In [16]:
backprop = Backprop(trained_model)

guided_gradients = backprop.calculate_gradients(input_, class_index, guided=True)
guided_max_gradients = backprop.calculate_gradients(input_, class_index, take_max=True, guided=True)

visualize(input_, guided_gradients, guided_max_gradients)
The trained model pays attention specifically to the distinguishing pattern of this particular species.

Thank you!